[feat] Track entropy and MI of routing distribution for topk MoE #188
base: main
Conversation
idea is good, thanks @oleksost.
bit weird that all these metrics are appearing as losses. that name should be reserved for things for which gradients are computed. just call this dict metrics?
Yes @tscholak, addressed. Using a metrics dict instead.
Looks good, got some comments on the structure.
@@ -135,3 +135,7 @@ def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]:
    @abc.abstractmethod
    def loss_defs(self) -> list[LossDef]:
        pass

    @property
This loss/metric split is way more complicated than needed. How about having a single entry, and using an is_metric flag in LossDef (or a derived class) to distinguish? Then no change is needed other than extracting metrics from the context before returning from run_step.
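A minimal sketch of that suggestion, assuming LossDef is a plain dataclass with name/formatted_name/count fields (the actual class in Fast-LLM may differ):

```python
# Hypothetical sketch: a single definition list with an is_metric flag,
# rather than parallel loss/metric structures. Field names are assumptions.
import dataclasses


@dataclasses.dataclass
class LossDef:
    name: str
    formatted_name: str
    count: int = 1
    # False for true losses (gradients are computed), True for tracked-only
    # quantities such as routing entropy and mutual information.
    is_metric: bool = False


def extract_metrics(reduced: dict[str, float], defs: list[LossDef]) -> dict[str, float]:
    """Pull the metric entries out of a reduced loss dict, e.g. before returning from run_step."""
    metric_names = {d.name for d in defs if d.is_metric}
    return {name: value for name, value in reduced.items() if name in metric_names}
```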
@@ -289,6 +312,19 @@ def _reduce_losses(self, context: BatchContext) -> dict[str, float | int]:
            for name, reduced_loss in reduced_losses.items()
        }

    def _is_reduced_metric(self, metric_name: str) -> bool:
        """Check if a metric should be reduced (is defined in a TransformerReducedMetrics subclass)."""
        from fast_llm.layers.transformer.config import TransformerReducedMetrics
We can't use hard-coded values here. Suggestion above would fix it, or there are a few other ways to get this dynamically.
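For instance, with the is_metric flag sketched above, the check could become fully dynamic instead of importing a hard-coded class (the loss_defs access path here is an assumption):

```python
# Hypothetical replacement for _is_reduced_metric: consult the model's
# declared loss/metric definitions rather than a hard-coded import.
def _is_reduced_metric(self, metric_name: str) -> bool:
    """True if metric_name was declared with is_metric=True in the model's loss_defs."""
    return any(d.is_metric and d.name == metric_name for d in self._multi_stage.base_model.loss_defs)
```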
# Store these metrics
if metrics is not None:
Given the extra computation involved, this should be enabled through a config parameter
how much compute are we talking about for these metrics? likely this won't be noticeable.
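If the config route is taken, a hypothetical shape (Fast-LLM's real configs use their own Field/config_class machinery, so this is only illustrative):

```python
# Illustrative opt-in flag; the name and placement are assumptions.
import dataclasses


@dataclasses.dataclass
class RouterMetricsConfig:
    # Off by default so the extra entropy/MI computation costs nothing
    # unless explicitly requested.
    calculate_routing_metrics: bool = False
```

The router forward would then guard the computation with something like `if self._config.calculate_routing_metrics and metrics is not None:`.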
assert 0.0 < mutual_info < 1.0, f"Expected value between 0 and 1, got {mutual_info}"


def test_edge_cases():
More explicit name?
@pytest.fixture
def setup_runner():
These don't belong here. How about test_runner.py?
why don't they belong here? this fixture is only useful for the tests in this suite.
if __name__ == "__main__":
Not needed
@@ -26,6 +27,35 @@
logger = logging.getLogger(__name__)


def calculate_normalized_average_entropy(probs: torch.Tensor) -> torch.Tensor:
Could try @torch.compile on these for a free performance boost.
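For example (illustrative body; the decorator is the point, and torch.compile requires PyTorch 2.x):

```python
import math

import torch


@torch.compile  # lets PyTorch fuse the elementwise log/mul/sum into fewer kernels
def calculate_normalized_average_entropy(probs: torch.Tensor) -> torch.Tensor:
    # Per-token entropy over the expert dimension, clamped to avoid log(0),
    # then averaged and normalized by log(n_experts) to land in [0, 1].
    entropy = -(probs.clamp_min(1e-9) * probs.clamp_min(1e-9).log()).sum(dim=-1)
    return entropy.mean() / math.log(probs.size(-1))
```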
    average_entropy = entropy_values.mean()  # Average over batch and tokens
    return average_entropy / torch.log(torch.tensor(n_experts, dtype=probs.dtype, device=probs.device))


def entropy(probs: torch.Tensor) -> torch.Tensor:
calculate_entropy?
✨ Description
To better detect potential routing collapse and to better understand the routing distribution, we can track the average entropy and the mutual information of the routing probabilities.
Collapsed routing would have low entropy and low mutual information. A healthy, specialised router would have low entropy and high mutual information, meaning that routing is confident and considerably different across tokens.
More specifically, mutual information measures the difference between the entropy of the batch-averaged routing distribution and the average entropy of the per-token routing distributions: it is high when each token routes confidently but different tokens pick different experts.
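A sketch of both quantities for a routing-probability tensor (the function names mirror the helpers added in this PR, but the shapes and normalization shown here are assumptions):

```python
import torch


def calculate_entropy(probs: torch.Tensor) -> torch.Tensor:
    # Entropy over the expert dimension; clamping avoids log(0) for
    # topk-masked probabilities.
    probs = probs.clamp_min(1e-9)
    return -(probs * probs.log()).sum(dim=-1)


def calculate_normalized_average_entropy(probs: torch.Tensor) -> torch.Tensor:
    # Mean per-token entropy, normalized by log(n_experts) to land in [0, 1].
    # probs: (batch, seq_len, n_experts), rows summing to 1.
    n_experts = probs.size(-1)
    return calculate_entropy(probs).mean() / torch.log(
        torch.tensor(n_experts, dtype=probs.dtype, device=probs.device)
    )


def calculate_mutual_information(probs: torch.Tensor) -> torch.Tensor:
    # MI = H(batch-averaged routing distribution) - mean per-token entropy.
    # Low everywhere => collapse; low per-token entropy with high MI =>
    # confident yet diverse (specialised) routing.
    n_experts = probs.size(-1)
    marginal = probs.reshape(-1, n_experts).mean(dim=0)
    mi = calculate_entropy(marginal) - calculate_entropy(probs).mean()
    return mi / torch.log(torch.tensor(n_experts, dtype=probs.dtype, device=probs.device))
```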
🔍 Type of change
Select all that apply:
📝 Changes
Entropy and mutual-information metrics are computed in mixture_of_experts.py; they are calculated only for the topk routing type.
✅ Checklist
General
Testing
Performance Impact
📊 Performance Impact Details
I am not 100% sure there is no performance impact; we are calculating the stats at each forward pass through the router.
🗒️ Additional Notes